{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ee093335",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Reconstructing a Point Cloud with DMTet\n",
    "\n",
    "Deep Marching Tetrahedra (DMTet) is a hybrid 3D representation that combines implicit and explicit surface representations. It represents a shape with a discrete signed distance function (SDF) defined on the vertices of a deformable tetrahedral grid. The SDF is converted to a triangular mesh by a differentiable marching tetrahedra (MT) layer, so supervision applied to the extracted surface can be back-propagated to the SDF and can change the mesh topology. In this tutorial, we demonstrate this by optimizing DMTet to reconstruct a point cloud by minimizing the Chamfer distance. The key function used in this tutorial is `kaolin.ops.conversions.marching_tetrahedra`; see the detailed [API documentation](https://kaolin.readthedocs.io/en/latest/modules/kaolin.ops.conversions.html#kaolin-ops-conversions)."
   ]
  },
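  {
   "cell_type": "markdown",
   "id": "5b1e0a01",
   "metadata": {},
   "source": [
    "The MT layer is differentiable because of how it places surface vertices. The sketch below uses the linear zero-crossing interpolation from the DMTet paper, which we assume `kaolin.ops.conversions.marching_tetrahedra` also follows: on a grid edge whose endpoints $a$ and $b$ have SDF values $s_a$ and $s_b$ of opposite sign, the surface vertex is placed at $p = (a s_b - b s_a) / (s_b - s_a)$. Since $p$ is a smooth function of $s_a$ and $s_b$, gradients of any loss on $p$ flow back to the SDF. The helper `edge_zero_crossing` is hypothetical and only for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b1e0a02",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def edge_zero_crossing(a, b, s_a, s_b):\n",
    "    # Hypothetical helper: linear interpolation of the SDF zero crossing on edge (a, b).\n",
    "    # Assumes s_a and s_b have opposite signs, so s_b - s_a is nonzero.\n",
    "    return (a * s_b - b * s_a) / (s_b - s_a)\n",
    "\n",
    "demo_a = torch.tensor([0.0, 0.0, 0.0])\n",
    "demo_b = torch.tensor([1.0, 0.0, 0.0])\n",
    "demo_edge_sdf = torch.tensor([-0.25, 0.75], requires_grad=True)  # SDF at a and b\n",
    "\n",
    "demo_p = edge_zero_crossing(demo_a, demo_b, demo_edge_sdf[0], demo_edge_sdf[1])\n",
    "print(demo_p)  # the surface vertex lies at x = 0.25\n",
    "\n",
    "# Back-propagate a loss on the vertex position (here simply its x coordinate) to the SDF\n",
    "demo_p[0].backward()\n",
    "print(demo_edge_sdf.grad)  # expected: [-0.75, -0.25]"
   ]
  },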
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "31d9198f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import kaolin\n",
    "import numpy as np\n",
    "from dmtet_network import Decoder\n",
    "\n",
    "# path to the point cloud to be reconstructed\n",
    "pcd_path = \"../samples/bear_pointcloud.usd\"\n",
    "# path to the output logs (readable with the training visualizer in the Omniverse app)\n",
    "logs_path = './logs/'\n",
    "\n",
    "# Initialize the timelapse that writes USD checkpoints for the visualization apps\n",
    "timelapse = kaolin.visualize.Timelapse(logs_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "58c9c196",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# arguments and hyperparameters\n",
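    "device = 'cuda'\n",
    "lr = 1e-3                # Adam learning rate\n",
    "laplacian_weight = 0.1   # weight of the Laplacian smoothness term\n",
    "iterations = 5000        # number of optimization steps\n",
    "save_every = 100         # log and save a mesh checkpoint every N iterations\n",
    "multires = 2             # positional-encoding setting passed to Decoder\n",
    "grid_res = 128           # resolution of the pre-generated tetrahedral grid"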
   ]
  },
  {
   "cell_type": "markdown",
   "id": "16a2899d",
   "metadata": {},
   "source": [
    "# Loading the Point Cloud\n",
    "\n",
    "In this example, we use a point cloud generated by the [Omniverse Kaolin App](https://docs.omniverse.nvidia.com/app_kaolin/app_kaolin/user_manual.html#data-generator). We load the pre-generated point cloud from `examples/samples/`, randomly subsample it to at most 100,000 points, and center and normalize it to the range of the tetrahedral grid. The normalized point cloud is saved to the checkpoint, which can be visualized using [the Omniverse app](https://docs.omniverse.nvidia.com/app_kaolin/app_kaolin).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5674d9a2",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n",
      "torch.Size([89164, 3])\n"
     ]
    }
   ],
   "source": [
    "points = kaolin.io.usd.import_pointclouds(pcd_path)[0].points.to(device)\n",
    "# Randomly subsample large point clouds to at most 100,000 points\n",
    "max_points = 100000\n",
    "if points.shape[0] > max_points:\n",
    "    idx = torch.randperm(points.shape[0], device=points.device)[:max_points]\n",
    "    points = points[idx]\n",
    "\n",
    "# Center and normalize the point cloud, then shrink it slightly: the reconstructed object\n",
    "# must be a bit smaller than the grid for MT to extract a watertight surface.\n",
    "points = kaolin.ops.pointcloud.center_points(points.unsqueeze(0), normalize=True).squeeze(0) * 0.9\n",
    "timelapse.add_pointcloud_batch(category='input',\n",
    "                               pointcloud_list=[points.cpu()], points_type=\"usd_geom_points\")"
   ]
  },
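  {
   "cell_type": "markdown",
   "id": "5b1e0a03",
   "metadata": {},
   "source": [
    "The centering and normalization above are handled by `kaolin.ops.pointcloud.center_points`. To illustrate the idea, the hypothetical helper `normalize_to_box` below centers a cloud on its bounding-box midpoint and divides by its largest extent, so the cloud fits in $[-0.5, 0.5]^3$, and then applies the same 0.9 margin. This is an assumption about one reasonable normalization, not necessarily kaolin's exact formula."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b1e0a04",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def normalize_to_box(pts, margin=0.9):\n",
    "    # Hypothetical helper (assumption, not kaolin's code): center on the bounding-box\n",
    "    # midpoint, scale the largest extent to 1, then shrink by `margin`.\n",
    "    vmin = pts.min(dim=0).values\n",
    "    vmax = pts.max(dim=0).values\n",
    "    center = (vmin + vmax) / 2\n",
    "    extent = (vmax - vmin).max().clamp(min=1e-6)\n",
    "    return (pts - center) / extent * margin\n",
    "\n",
    "demo_pts = torch.tensor([[0.0, 0.0, 0.0], [4.0, 2.0, 1.0], [2.0, 1.0, 0.5]])\n",
    "demo_normed = normalize_to_box(demo_pts)\n",
    "print(demo_normed.min(dim=0).values, demo_normed.max(dim=0).values)  # x spans [-0.45, 0.45]"
   ]
  },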
  {
   "cell_type": "markdown",
   "id": "8b39b36c",
   "metadata": {},
   "source": [
    "# Loading the Tetrahedral Grid\n",
    "\n",
    "DMTet starts from a uniform tetrahedral grid of predefined resolution and uses a network to predict an SDF value and a deformation vector at each grid vertex.\n",
    "\n",
    "Here we load a tetrahedral grid pre-generated with [Quartet](https://github.com/crawforddoran/quartet) at resolution 128. It has 277,410 vertices, roughly as many as a voxel grid of resolution 65 (65³ = 274,625). We use a simple MLP with positional encoding to predict the SDF and deformation vectors, and pre-train it so that the SDF initially represents a sphere."
   ]
  },
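  {
   "cell_type": "markdown",
   "id": "5b1e0a05",
   "metadata": {},
   "source": [
    "The grid is stored as two tensors: vertex positions of shape `(num_verts, 3)` and tetrahedra of shape `(num_tets, 4)`, where each row holds four vertex indices. The toy sketch below builds the same layout for a single unit cube split into six tetrahedra (the Kuhn subdivision, chosen here only for illustration; Quartet generates a different, much larger grid) and checks that the tetrahedra exactly fill the cube."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b1e0a06",
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "import torch\n",
    "\n",
    "# Cube corners ordered so that the corner (x, y, z) has index 4*x + 2*y + z\n",
    "toy_verts = torch.tensor(list(itertools.product([0.0, 1.0], repeat=3)))  # (8, 3)\n",
    "\n",
    "# Kuhn subdivision: one tetrahedron per ordering of the three axes, following a\n",
    "# monotone path of cube edges from corner (0, 0, 0) to corner (1, 1, 1)\n",
    "toy_tet_list = []\n",
    "for perm in itertools.permutations(range(3)):\n",
    "    corner = [0, 0, 0]\n",
    "    path = [0]\n",
    "    for axis in perm:\n",
    "        corner[axis] = 1\n",
    "        path.append(4 * corner[0] + 2 * corner[1] + corner[2])\n",
    "    toy_tet_list.append(path)\n",
    "toy_tets = torch.tensor(toy_tet_list, dtype=torch.long)  # (6, 4)\n",
    "\n",
    "# Volume of each tetrahedron: |det([v1 - v0, v2 - v0, v3 - v0])| / 6\n",
    "toy_corners = toy_verts[toy_tets]                         # (6, 4, 3)\n",
    "toy_edge_vecs = toy_corners[:, 1:] - toy_corners[:, :1]   # (6, 3, 3)\n",
    "toy_volumes = torch.linalg.det(toy_edge_vecs).abs() / 6\n",
    "print(toy_verts.shape, toy_tets.shape, toy_volumes.sum())  # volumes sum to 1, the unit cube"
   ]
  },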
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "33ab4b6f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([277410, 3]) torch.Size([1524684, 4])\n",
      "Initialize SDF to sphere\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:03<00:00, 279.25it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Pre-trained MLP 5.480436811922118e-06\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# tet_verts: (num_verts, 3) vertex positions; tets: (num_tets, 4) vertex indices\n",
    "tet_verts = torch.tensor(np.load('../samples/{}_verts.npz'.format(grid_res))['data'], dtype=torch.float, device=device)\n",
    "tets = torch.tensor(np.stack([np.load('../samples/{}_tets_{}.npz'.format(grid_res, i))['data'] for i in range(4)]), dtype=torch.long, device=device).permute(1, 0)\n",
    "print(tet_verts.shape, tets.shape)\n",
    "\n",
    "# Initialize the model and pre-train it so that the predicted SDF represents a sphere\n",
    "model = Decoder(multires=multires).to(device)\n",
    "model.pre_train_sphere(1000)\n"
   ]
  },
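  {
   "cell_type": "markdown",
   "id": "5b1e0a07",
   "metadata": {},
   "source": [
    "`Decoder` and its `pre_train_sphere` method come from the accompanying `dmtet_network.py`, which is not shown here. As a hedged sketch of the general idea, not of that file's code, the cell below encodes query points with a NeRF-style positional encoding using `multires` frequency bands and fits a small MLP so that its first output channel matches the SDF of a sphere, $\\|x\\| - r$ (negative inside). The radius 0.3, the sampling cube $[-0.5, 0.5]^3$, the layer sizes and the 500 steps are illustrative assumptions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b1e0a08",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def positional_encoding(x, multires):\n",
    "    # NeRF-style encoding (an assumption): [x, sin(2^k x), cos(2^k x)] for k < multires\n",
    "    feats = [x]\n",
    "    for k in range(multires):\n",
    "        feats.append(torch.sin((2.0 ** k) * x))\n",
    "        feats.append(torch.cos((2.0 ** k) * x))\n",
    "    return torch.cat(feats, dim=-1)\n",
    "\n",
    "def sphere_sdf(x, radius=0.3):\n",
    "    # Signed distance to a sphere centered at the origin: negative inside, positive outside\n",
    "    return x.norm(dim=-1) - radius\n",
    "\n",
    "torch.manual_seed(0)\n",
    "demo_multires = 2\n",
    "demo_mlp = torch.nn.Sequential(\n",
    "    torch.nn.Linear(3 * (1 + 2 * demo_multires), 64), torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 64), torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 4),  # channel 0: SDF, channels 1-3: deformation\n",
    ")\n",
    "demo_opt = torch.optim.Adam(demo_mlp.parameters(), lr=1e-3)\n",
    "for _ in range(500):\n",
    "    demo_x = torch.rand(1024, 3) - 0.5  # random query points in [-0.5, 0.5]^3\n",
    "    demo_pred_sdf = demo_mlp(positional_encoding(demo_x, demo_multires))[:, 0]\n",
    "    demo_sphere_loss = torch.nn.functional.mse_loss(demo_pred_sdf, sphere_sdf(demo_x))\n",
    "    demo_opt.zero_grad()\n",
    "    demo_sphere_loss.backward()\n",
    "    demo_opt.step()\n",
    "print('sphere fit MSE:', demo_sphere_loss.item())"
   ]
  },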
  {
   "cell_type": "markdown",
   "id": "73fe95a7",
   "metadata": {},
   "source": [
    "# Preparing the Losses and Regularizer\n",
    "\n",
    "During training we use two losses defined on the surface mesh:\n",
    "- We use the Chamfer distance as the reconstruction loss. At each step, we randomly sample 50,000 points from the surface mesh and compute nearest-neighbor distances between them and the ground-truth point cloud in both directions.\n",
    "- Because DMTet produces an explicit mesh, it can apply regularization directly on the surface to impose useful geometric constraints. We demonstrate this with a Laplacian loss that encourages a smooth surface; it is added, weighted by `laplacian_weight`, only during the second half of training.\n"
   ]
  },
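  {
   "cell_type": "markdown",
   "id": "5b1e0a09",
   "metadata": {},
   "source": [
    "For small clouds, the Chamfer distance can be written directly with `torch.cdist`. The sketch below averages squared nearest-neighbor distances in each direction and sums the two terms; we assume this matches the default behavior of `kaolin.metrics.pointcloud.chamfer_distance`, which is batched and built for large clouds. A dense version does not scale to this tutorial: a 50,000 × 89,164 float32 distance matrix would take about 17.8 GB."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b1e0a10",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def chamfer_distance_dense(p1, p2):\n",
    "    # p1: (N, 3), p2: (M, 3). Builds the full (N, M) matrix, so only use on small clouds.\n",
    "    d = torch.cdist(p1, p2) ** 2             # squared Euclidean distances\n",
    "    p1_to_p2 = d.min(dim=1).values.mean()    # each p1 point to its nearest p2 point\n",
    "    p2_to_p1 = d.min(dim=0).values.mean()    # each p2 point to its nearest p1 point\n",
    "    return p1_to_p2 + p2_to_p1\n",
    "\n",
    "demo_p1 = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])\n",
    "demo_p2 = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 2.0, 0.0]])\n",
    "print(chamfer_distance_dense(demo_p1, demo_p1))  # identical clouds -> 0\n",
    "print(chamfer_distance_dense(demo_p1, demo_p2))  # 4/3: (0, 2, 0) is at squared distance 4 from its nearest point"
   ]
  },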
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "78ad11ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Laplacian regularization using umbrella operator (Fujiwara / Desbrun).\n",
    "# https://mgarland.org/class/geom04/material/smoothing.pdf\n",
    "def laplace_regularizer_const(mesh_verts, mesh_faces):\n",
    "    term = torch.zeros_like(mesh_verts)\n",
    "    norm = torch.zeros_like(mesh_verts[..., 0:1])\n",
    "\n",
    "    v0 = mesh_verts[mesh_faces[:, 0], :]\n",
    "    v1 = mesh_verts[mesh_faces[:, 1], :]\n",
    "    v2 = mesh_verts[mesh_faces[:, 2], :]\n",
    "\n",
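    "    # For each vertex, accumulate the offsets to the other two vertices of every incident face\n",
    "    term.scatter_add_(0, mesh_faces[:, 0:1].repeat(1,3), (v1 - v0) + (v2 - v0))\n",
    "    term.scatter_add_(0, mesh_faces[:, 1:2].repeat(1,3), (v0 - v1) + (v2 - v1))\n",
    "    term.scatter_add_(0, mesh_faces[:, 2:3].repeat(1,3), (v0 - v2) + (v1 - v2))\n",
    "\n",
    "    # Each incident face contributes two neighbors to the per-vertex count\n",
    "    two = torch.ones_like(v0) * 2.0\n",
    "    norm.scatter_add_(0, mesh_faces[:, 0:1], two)\n",
    "    norm.scatter_add_(0, mesh_faces[:, 1:2], two)\n",
    "    norm.scatter_add_(0, mesh_faces[:, 2:3], two)\n",
    "\n",
    "    # Averaged offset: mean neighbor position minus the vertex (uniform Laplacian)\n",
    "    term = term / torch.clamp(norm, min=1.0)\n",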
    "\n",
    "    return torch.mean(term**2)\n",
    "\n",
    "def loss_f(mesh_verts, mesh_faces, points, it):\n",
    "    # Sample 50,000 points on the extracted surface and compare them with the target cloud\n",
    "    pred_points = kaolin.ops.mesh.sample_points(mesh_verts.unsqueeze(0), mesh_faces, 50000)[0][0]\n",
    "    chamfer = kaolin.metrics.pointcloud.chamfer_distance(pred_points.unsqueeze(0), points.unsqueeze(0)).mean()\n",
    "    # Add the Laplacian smoothness term only during the second half of training\n",
    "    if it > iterations // 2:\n",
    "        lap = laplace_regularizer_const(mesh_verts, mesh_faces)\n",
    "        return chamfer + lap * laplacian_weight\n",
    "    return chamfer\n"
   ]
  },
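  {
   "cell_type": "markdown",
   "id": "5b1e0a11",
   "metadata": {},
   "source": [
    "On a closed mesh, where every edge is shared by exactly two faces, the face-based accumulation in `laplace_regularizer_const` weights every neighbor equally, so it computes the uniform (umbrella) Laplacian: the mean of a vertex's neighbors minus the vertex itself. The hypothetical `uniform_laplacian_dense` below computes the same quantity from a dense adjacency matrix, which is only practical for tiny meshes, and evaluates it on the closed surface of a tetrahedron, where the reduction `mean(term**2)` works out to 1/3."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b1e0a12",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def uniform_laplacian_dense(verts, faces):\n",
    "    # Hypothetical dense reference for tiny meshes: mean of neighbors minus the vertex\n",
    "    n = verts.shape[0]\n",
    "    adj = torch.zeros(n, n)\n",
    "    for i, j in [(0, 1), (1, 2), (2, 0)]:\n",
    "        adj[faces[:, i], faces[:, j]] = 1.0  # mark each face edge in both directions\n",
    "        adj[faces[:, j], faces[:, i]] = 1.0\n",
    "    deg = adj.sum(dim=1, keepdim=True).clamp(min=1.0)\n",
    "    return adj @ verts / deg - verts\n",
    "\n",
    "# Closed surface of a tetrahedron: every edge is shared by exactly two faces\n",
    "demo_verts = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])\n",
    "demo_faces = torch.tensor([[0, 2, 1], [0, 1, 3], [0, 3, 2], [1, 2, 3]])\n",
    "demo_lap = uniform_laplacian_dense(demo_verts, demo_faces)\n",
    "print(demo_lap[0])                # mean of the other three corners: [1/3, 1/3, 1/3]\n",
    "print(torch.mean(demo_lap ** 2))  # same reduction as laplace_regularizer_const: 1/3"
   ]
  },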
  {
   "cell_type": "markdown",
   "id": "0f96974c",
   "metadata": {},
   "source": [
    "# Setting up the Optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a5d4a42f",
   "metadata": {},
   "outputs": [],
   "source": [
    "params = list(model.parameters())\n",
    "optimizer = torch.optim.Adam(params, lr=lr)\n",
    "scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: 10**(-x * 0.0002)) # exponential LR decay over time"
   ]
  },
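  {
   "cell_type": "markdown",
   "id": "5b1e0a13",
   "metadata": {},
   "source": [
    "The scheduler multiplies `lr` by $10^{-0.0002 x}$ after $x$ steps, and `scheduler.step()` is called once per iteration, so the learning rate shrinks by a factor of 10 over the 5,000 iterations, from 1e-3 to 1e-4. The plain-Python snippet below evaluates the same multiplier at a few steps; `lr_multiplier` is a hypothetical helper that mirrors the lambda above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b1e0a14",
   "metadata": {},
   "outputs": [],
   "source": [
    "def lr_multiplier(step):\n",
    "    # Same multiplier as the LambdaLR lambda above (hypothetical standalone copy)\n",
    "    return 10 ** (-step * 0.0002)\n",
    "\n",
    "demo_base_lr = 1e-3  # same value as `lr` above\n",
    "for demo_step in [0, 1000, 2500, 5000]:\n",
    "    print(demo_step, demo_base_lr * lr_multiplier(demo_step))  # 1e-3 at step 0, about 1e-4 at step 5000"
   ]
  },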
  {
   "cell_type": "markdown",
   "id": "e7917ee1",
   "metadata": {},
   "source": [
    "# Training\n",
    "\n",
    "At every iteration, the network first predicts the SDF value and deformation vector at each grid vertex. Next, we extract a triangle mesh by running marching tetrahedra on the deformed grid. We then compute the losses on the extracted mesh and back-propagate the gradients to the network weights. Notice that the mesh changes during training: the number of vertices and faces in the printed log varies between iterations because the surface is re-extracted from the updated SDF at every step. Training takes about 5 minutes on a TITAN RTX GPU."
   ]
  },
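  {
   "cell_type": "markdown",
   "id": "5b1e0a15",
   "metadata": {},
   "source": [
    "The face counts in the training log follow from how marching tetrahedra treats each tetrahedron: a tet whose four corners lie on the same side of the surface emits nothing, a 1-vs-3 split emits one triangle, and a 2-vs-2 split emits a quad, i.e. two triangles. The hypothetical helper `count_mt_faces` below counts faces this way from the SDF signs alone, so the result does not depend on which sign is treated as inside. Assuming kaolin's MT splits every quad into two triangles, calling it on `sdf` and `tets` inside the loop should agree with `mesh_faces.shape[0]`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b1e0a16",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def count_mt_faces(sdf_values, tet_indices):\n",
    "    # Hypothetical helper. sdf_values: (num_verts,), tet_indices: (num_tets, 4).\n",
    "    occ = (sdf_values > 0)[tet_indices]  # (num_tets, 4) side of the surface per corner\n",
    "    n_pos = occ.sum(dim=1)               # corners on the positive side, 0..4\n",
    "    tris_per_tet = torch.zeros_like(n_pos)\n",
    "    tris_per_tet[(n_pos == 1) | (n_pos == 3)] = 1  # one corner cut off: one triangle\n",
    "    tris_per_tet[n_pos == 2] = 2                   # 2-vs-2 split: a quad, i.e. two triangles\n",
    "    return int(tris_per_tet.sum())\n",
    "\n",
    "# Toy grid: two tetrahedra sharing the face (0, 1, 2)\n",
    "demo_tets = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 4]])\n",
    "demo_sdf = torch.tensor([-1.0, -1.0, 0.5, 0.5, -0.2])\n",
    "demo_faces_count = count_mt_faces(demo_sdf, demo_tets)\n",
    "print(demo_faces_count)  # 2 (first tet, 2-vs-2) + 1 (second tet, 1-vs-3) = 3"
   ]
  },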
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "583bec8b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration 0 - loss: 0.02473130077123642, # of mesh vertices: 18110, # of mesh faces: 36216\n",
      "Iteration 100 - loss: 0.002605137648060918, # of mesh vertices: 24234, # of mesh faces: 48464\n",
      "Iteration 200 - loss: 0.0003765518486034125, # of mesh vertices: 26862, # of mesh faces: 53720\n",
      "Iteration 300 - loss: 0.0010241996496915817, # of mesh vertices: 31508, # of mesh faces: 63012\n",
      "Iteration 400 - loss: 0.0001085952389985323, # of mesh vertices: 28300, # of mesh faces: 56596\n",
      "Iteration 500 - loss: 7.9919038398657e-05, # of mesh vertices: 28710, # of mesh faces: 57416\n",
      "Iteration 600 - loss: 0.00010018410830525681, # of mesh vertices: 27400, # of mesh faces: 54796\n",
      "Iteration 700 - loss: 6.0749654949177057e-05, # of mesh vertices: 28494, # of mesh faces: 56984\n",
      "Iteration 800 - loss: 0.0002924088039435446, # of mesh vertices: 27660, # of mesh faces: 55316\n",
      "Iteration 900 - loss: 9.263768151868135e-05, # of mesh vertices: 28512, # of mesh faces: 57020\n",
      "Iteration 1000 - loss: 7.250437192851678e-05, # of mesh vertices: 28598, # of mesh faces: 57192\n",
      "Iteration 1100 - loss: 6.00546263740398e-05, # of mesh vertices: 28352, # of mesh faces: 56700\n",
      "Iteration 1200 - loss: 4.965237167198211e-05, # of mesh vertices: 28606, # of mesh faces: 57208\n",
      "Iteration 1300 - loss: 4.5047825551591814e-05, # of mesh vertices: 28934, # of mesh faces: 57864\n",
      "Iteration 1400 - loss: 4.2731968278530985e-05, # of mesh vertices: 28878, # of mesh faces: 57752\n",
      "Iteration 1500 - loss: 8.582305599702522e-05, # of mesh vertices: 28790, # of mesh faces: 57576\n",
      "Iteration 1600 - loss: 4.140706005273387e-05, # of mesh vertices: 28924, # of mesh faces: 57844\n",
      "Iteration 1700 - loss: 3.995447332272306e-05, # of mesh vertices: 28850, # of mesh faces: 57696\n",
      "Iteration 1800 - loss: 3.944659692933783e-05, # of mesh vertices: 29064, # of mesh faces: 58128\n",
      "Iteration 1900 - loss: 3.890909647452645e-05, # of mesh vertices: 28994, # of mesh faces: 57984\n",
      "Iteration 2000 - loss: 3.9877151721157134e-05, # of mesh vertices: 28832, # of mesh faces: 57660\n",
      "Iteration 2100 - loss: 3.8087084249127656e-05, # of mesh vertices: 28942, # of mesh faces: 57880\n",
      "Iteration 2200 - loss: 3.8198602851480246e-05, # of mesh vertices: 29116, # of mesh faces: 58228\n",
      "Iteration 2300 - loss: 3.789698894252069e-05, # of mesh vertices: 29188, # of mesh faces: 58372\n",
      "Iteration 2400 - loss: 3.733349876711145e-05, # of mesh vertices: 28986, # of mesh faces: 57968\n",
      "Iteration 2500 - loss: 3.886773993144743e-05, # of mesh vertices: 28728, # of mesh faces: 57452\n",
      "Iteration 2600 - loss: 3.7754220102215186e-05, # of mesh vertices: 29132, # of mesh faces: 58260\n",
      "Iteration 2700 - loss: 3.751121403183788e-05, # of mesh vertices: 28962, # of mesh faces: 57920\n",
      "Iteration 2800 - loss: 3.733678022399545e-05, # of mesh vertices: 28942, # of mesh faces: 57880\n",
      "Iteration 2900 - loss: 3.712274701683782e-05, # of mesh vertices: 28970, # of mesh faces: 57936\n",
      "Iteration 3000 - loss: 3.738816667464562e-05, # of mesh vertices: 29154, # of mesh faces: 58304\n",
      "Iteration 3100 - loss: 3.6861980333924294e-05, # of mesh vertices: 29090, # of mesh faces: 58176\n",
      "Iteration 3200 - loss: 3.7955178413540125e-05, # of mesh vertices: 29228, # of mesh faces: 58452\n",
      "Iteration 3300 - loss: 3.692376412800513e-05, # of mesh vertices: 28990, # of mesh faces: 57976\n",
      "Iteration 3400 - loss: 3.6803434340981767e-05, # of mesh vertices: 29032, # of mesh faces: 58060\n",
      "Iteration 3500 - loss: 3.666708289529197e-05, # of mesh vertices: 29006, # of mesh faces: 58008\n",
      "Iteration 3600 - loss: 3.6867546441499144e-05, # of mesh vertices: 28916, # of mesh faces: 57828\n",
      "Iteration 3700 - loss: 3.673196624731645e-05, # of mesh vertices: 28876, # of mesh faces: 57748\n",
      "Iteration 3800 - loss: 3.683008617372252e-05, # of mesh vertices: 28868, # of mesh faces: 57732\n",
      "Iteration 3900 - loss: 3.696472413139418e-05, # of mesh vertices: 28932, # of mesh faces: 57860\n",
      "Iteration 4000 - loss: 3.699162698467262e-05, # of mesh vertices: 29188, # of mesh faces: 58372\n",
      "Iteration 4100 - loss: 3.622782969614491e-05, # of mesh vertices: 28980, # of mesh faces: 57956\n",
      "Iteration 4200 - loss: 3.6102632293477654e-05, # of mesh vertices: 28990, # of mesh faces: 57976\n",
      "Iteration 4300 - loss: 3.6840694519924e-05, # of mesh vertices: 28888, # of mesh faces: 57772\n",
      "Iteration 4400 - loss: 3.603967707022093e-05, # of mesh vertices: 28992, # of mesh faces: 57980\n",
      "Iteration 4500 - loss: 3.609260966186412e-05, # of mesh vertices: 29044, # of mesh faces: 58084\n",
      "Iteration 4600 - loss: 3.623321754275821e-05, # of mesh vertices: 29112, # of mesh faces: 58220\n",
      "Iteration 4700 - loss: 3.591994391172193e-05, # of mesh vertices: 29116, # of mesh faces: 58228\n",
      "Iteration 4800 - loss: 3.641782677732408e-05, # of mesh vertices: 29148, # of mesh faces: 58292\n",
      "Iteration 4900 - loss: 3.601510252337903e-05, # of mesh vertices: 29078, # of mesh faces: 58152\n",
      "Iteration 4999 - loss: 3.580914199119434e-05, # of mesh vertices: 29056, # of mesh faces: 58108\n"
     ]
    }
   ],
   "source": [
    "for it in range(iterations):\n",
    "    pred = model(tet_verts) # predict SDF and per-vertex deformation\n",
    "    sdf, deform = pred[:,0], pred[:,1:]\n",
    "    verts_deformed = tet_verts + torch.tanh(deform) / grid_res # tanh bounds each offset to (-1/grid_res, 1/grid_res) per axis, which helps avoid flipping tets\n",
    "    mesh_verts, mesh_faces = kaolin.ops.conversions.marching_tetrahedra(verts_deformed.unsqueeze(0), tets, sdf.unsqueeze(0)) # running MT (batched) to extract surface mesh\n",
    "    mesh_verts, mesh_faces = mesh_verts[0], mesh_faces[0]\n",
    "\n",
    "    loss = loss_f(mesh_verts, mesh_faces, points, it)\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    scheduler.step()\n",
    "    if it % save_every == 0 or it == iterations - 1:\n",
    "        print('Iteration {} - loss: {}, # of mesh vertices: {}, # of mesh faces: {}'.format(it, loss.item(), mesh_verts.shape[0], mesh_faces.shape[0]))\n",
    "        # save reconstructed mesh\n",
    "        timelapse.add_mesh_batch(\n",
    "            iteration=it+1,\n",
    "            category='extracted_mesh',\n",
    "            vertices_list=[mesh_verts.cpu()],\n",
    "            faces_list=[mesh_faces.cpu()]\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e49e8f67",
   "metadata": {},
   "source": [
    "# Visualize Training\n",
    "\n",
    "You can now use [the Omniverse app](https://docs.omniverse.nvidia.com/app_kaolin/app_kaolin) to visualize how the mesh evolves during training by pointing its training visualizer at \"./logs/\", where we stored the checkpoints.\n",
    "\n",
    "Alternatively, you can use [kaolin-dash3d](https://kaolin.readthedocs.io/en/latest/notes/checkpoints.html?highlight=usd#visualizing-with-kaolin-dash3d) to visualize the checkpoints by running `kaolin-dash3d --logdir=./logs/ --port=8080` (here `./logs/` is the value of `logs_path`). This command launches a web server that streams geometry to web clients. You can then view the input point cloud and the reconstructed mesh at [localhost:8080](http://localhost:8080), as shown below. Change the *global iteration* on the left to see how the mesh evolves during training.\n",
    "\n",
    "![Reconstructed mesh in kaolin-dash3d](../samples/dash3d_mesh.png)\n",
    "![Input point cloud in kaolin-dash3d](../samples/dash3d_pcd.png)"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "4040fd28a16387d31474220157706b1752bd7f86ecfd14350c5c940438c26826"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
